#!/usr/bin/python3
# -*- coding:UTF-8 -*-

import AutoExecUtils
import os
import traceback
import datetime
import argparse
import re
import json
from bson.json_util import dumps, loads


def usage():
    """Print the command-line synopsis for this script and exit with status 1."""
    scriptName = os.path.basename(__file__)
    print("{} --outputfile <path> ".format(scriptName))
    exit(1)


def getObjCatPkDef(objCat, customPkDef):
    """Return the list of primary-key field names for an object category.

    Categories listed in the built-in table always use their fixed key
    definition; for any other category ``customPkDef`` is returned unchanged
    (it may be ``None``).
    """
    builtinPkDefs = {
        # Identified by the management IP alone.
        "OS": ["MGMT_IP"],
        "VIRTUALIZED": ["MGMT_IP"],
        "K8S": ["MGMT_IP"],
        "UNKNOWN": ["MGMT_IP"],
        # Identified by management IP plus a listening port.
        "INS": ["MGMT_IP", "PORT"],
        "DBINS": ["MGMT_IP", "PORT", "INSTANCE_NAME"],
        # Identified by primary IP, port and name.
        "DB": ["PRIMARY_IP", "PORT", "NAME"],
        "CLUSTER": ["PRIMARY_IP", "PORT", "NAME"],
        # Identified by management IP plus a serial number.
        "NETDEV": ["MGMT_IP", "SN"],
        "SECDEV": ["MGMT_IP", "SN"],
        "SWITCH": ["MGMT_IP", "SN"],
        "FIREWALL": ["MGMT_IP", "SN"],
        "LOADBALANCER": ["MGMT_IP", "SN"],
        "STORAGE": ["MGMT_IP", "SN"],
        "FCSWITCH": ["MGMT_IP", "SN"],
        "FCDEV": ["MGMT_IP", "SN"],
        # Remaining categories with their own key shape.
        "HOST": ["MGMT_IP", "BOARD_SERIAL"],
        "CONTAINER": ["MGMT_IP", "CONTAINER_ID"],
        "ADMINSET": ["UNIQUE_NAME"],
    }
    if objCat in builtinPkDefs:
        return builtinPkDefs[objCat]
    return customPkDef


def getObjCatIndexDef(objCat):
    """Return the default secondary-index field names for an object category.

    Categories without an entry in the table get an empty list.
    """
    defaultIndexFields = {
        "ADMINSET": ["UNIQUE_NAME"],
        "CLUSTER": ["MEMBER_PEER"],
        "DBINS": ["IP", "VIP"],
        "FCDEV": ["WWNN", "PORTS.WWPN", "LINK_TABLE.PEER_WWPN"],
        "FCSWITCH": ["WWNN", "PORTS.WWPN", "LINK_TABLE.PEER_WWPN"],
        "HOST": ["OS_ID", "GUESTOS_UUIDS"],
        "OS": ["OS_ID", "HBA_INTERFACES.WWNN", "HBA_INTERFACES.WWPN"],
        "STORAGE": ["HBA_INTERFACES.WWNN", "HBA_INTERFACES.WWPN"],
        "SWITCH": ["DEV_NAME"],
    }
    if objCat not in defaultIndexFields:
        return []
    return defaultIndexFields[objCat]


def createPKIndex(collection, pkDef):
    """Create the standard indexes of a collect collection.

    Three indexes are requested on ``collection``, in this order:
    a unique ascending compound index ``idx_pk`` over the fields in ``pkDef``,
    a TTL index ``idx_ttl`` on ``_renewtime``, and a plain index
    ``idx_batchtag`` on ``_batch_tag``.
    """
    # 183 days expressed in seconds (15811200).
    ttlSeconds = 183 * 24 * 60 * 60
    ascendingKeys = [(fieldName, 1) for fieldName in pkDef]

    collection.create_index(ascendingKeys, name="idx_pk", unique=True)
    collection.create_index([("_renewtime", 1)], name="idx_ttl", expireAfterSeconds=ttlSeconds)
    collection.create_index([("_batch_tag", 1)], name="idx_batchtag")


def createIndex(collection, data, objCat):
    """Create the secondary indexes of a collect collection.

    ``idx_obj_type`` on ``_OBJ_TYPE`` is always requested first. After that one
    single-field ascending index named ``idx_<field lower-cased>`` is requested
    per field: the fields come from ``data["INDEX_FIELDS"]`` when that key is
    present, otherwise from the category defaults of ``objCat``.
    """
    collection.create_index([("_OBJ_TYPE", 1)], name="idx_obj_type")

    if "INDEX_FIELDS" in data:
        extraFields = data["INDEX_FIELDS"]
    else:
        extraFields = getObjCatIndexDef(objCat)

    for fieldName in extraFields:
        indexName = "idx_" + fieldName.lower()
        collection.create_index([(fieldName, 1)], name=indexName)


# Purpose: make application processes discoverable through their PEER connection information.
def saveAppConnData(db, existsCollections, data):
    """Upsert the network-connection summary of an INS/DBINS object.

    Nothing happens unless ``data["_OBJ_CATEGORY"]`` is ``"INS"`` or ``"DBINS"``
    and ``data`` contains ``OS_ID``. Otherwise a document keyed by
    ``MGMT_IP``/``PORT`` is written to the ``RELATION_INS_NETCONN`` collection
    and, the first time this collection is seen in ``existsCollections``, its
    indexes are created and the collection is recorded there.

    Args:
        db: database handle supporting ``db[name]`` lookup of collections.
        existsCollections: dict of collection names already initialised;
            updated in place.
        data: collected object; must provide ``OS_ID``, ``MGMT_IP`` and
            ``PORT``. ``_OBJ_TYPE``, ``RESOURCE_ID``, ``CONN_INFO`` and
            ``_renewtime`` are optional.

    Raises:
        KeyError: if ``MGMT_IP`` or ``PORT`` is missing from ``data``.
    """
    if data.get("_OBJ_CATEGORY") not in ("INS", "DBINS") or "OS_ID" not in data:
        return

    # A missing, empty or partial CONN_INFO falls back to empty lists per key.
    connData = data.get("CONN_INFO") or {}
    connInfo = {
        "OS_ID": data["OS_ID"],
        "_OBJ_CATEGORY": data["_OBJ_CATEGORY"],
        # _OBJ_TYPE is optional for collected objects; do not abort the save when it is absent.
        "_OBJ_TYPE": data.get("_OBJ_TYPE"),
        "MGMT_IP": data["MGMT_IP"],
        "PORT": data["PORT"],
        "RESOURCE_ID": data.get("RESOURCE_ID"),
        "BIND": connData.get("BIND", []),
        "PEER": connData.get("PEER", []),
        "LOCAL_PEER": connData.get("LOCAL_PEER", []),
    }
    # The TTL index below is built on _renewtime; MongoDB never expires
    # documents that lack the indexed field, so it has to be stored here.
    if "_renewtime" in data:
        connInfo["_renewtime"] = data["_renewtime"]

    primaryKey = {"MGMT_IP": connInfo["MGMT_IP"], "PORT": connInfo["PORT"]}
    collection = db["RELATION_INS_NETCONN"]
    # BIND and PEER are plain arrays; once indexed they can be queried with
    # $in, and combined with $elemMatch to filter the matching BIND/PEER entries.
    collection.replace_one(primaryKey, connInfo, upsert=True)
    if "RELATION_INS_NETCONN" not in existsCollections:
        collection.create_index([("MGMT_IP", 1), ("PORT", 1)], name="idx_pk")
        collection.create_index([("BIND", 1)], name="idx_bind")
        collection.create_index([("PEER", 1)], name="idx_peer")
        # 15811200 seconds = 183 days.
        collection.create_index([("_renewtime", 1)], name="idx_ttl", expireAfterSeconds=15811200)
        existsCollections["RELATION_INS_NETCONN"] = 1

    print("INFO: Save connection data success.")


def saveSwMacTableData(db, existsCollections, data):
    """Upsert the MAC table of a SWITCH object.

    Nothing happens unless ``data["_OBJ_CATEGORY"]`` is ``"SWITCH"`` and
    ``data`` contains ``MAC_TABLE``. Otherwise a document keyed by
    ``MGMT_IP``/``SN`` is written to the ``RELATION_NET_MACTABLE`` collection
    and, the first time this collection is seen in ``existsCollections``, its
    indexes are created and the collection is recorded there.

    Args:
        db: database handle supporting ``db[name]`` lookup of collections.
        existsCollections: dict of collection names already initialised;
            updated in place.
        data: collected object; must provide ``MGMT_IP``, ``SN`` and
            ``MAC_TABLE``. ``_OBJ_TYPE``, ``DEV_NAME``, ``RESOURCE_ID`` and
            ``_renewtime`` are optional.

    Raises:
        KeyError: if ``MGMT_IP`` or ``SN`` is missing from ``data``.
    """
    if data.get("_OBJ_CATEGORY") != "SWITCH" or "MAC_TABLE" not in data:
        return

    macTableInfo = {
        "_OBJ_CATEGORY": data["_OBJ_CATEGORY"],
        # _OBJ_TYPE and DEV_NAME are descriptive only; a switch that lacks
        # them must still get its MAC table (and its main record) saved.
        "_OBJ_TYPE": data.get("_OBJ_TYPE"),
        "MGMT_IP": data["MGMT_IP"],
        "SN": data["SN"],
        "DEV_NAME": data.get("DEV_NAME"),
        "RESOURCE_ID": data.get("RESOURCE_ID"),
        "MAC_TABLE": data["MAC_TABLE"],
    }
    # The TTL index below is built on _renewtime; MongoDB never expires
    # documents that lack the indexed field, so it has to be stored here.
    if "_renewtime" in data:
        macTableInfo["_renewtime"] = data["_renewtime"]

    primaryKey = {"MGMT_IP": macTableInfo["MGMT_IP"], "SN": macTableInfo["SN"]}
    collection = db["RELATION_NET_MACTABLE"]
    collection.replace_one(primaryKey, macTableInfo, upsert=True)
    if "RELATION_NET_MACTABLE" not in existsCollections:
        collection.create_index([("MGMT_IP", 1), ("SN", 1)], name="idx_pk")
        # The MAC_TABLE.MACS index supports $in lookups, combined with
        # $elemMatch to return the matching MAC records.
        collection.create_index([("MAC_TABLE.MACS", 1)], name="idx_mac")
        # 15811200 seconds = 183 days.
        collection.create_index([("_renewtime", 1)], name="idx_ttl", expireAfterSeconds=15811200)
        existsCollections["RELATION_NET_MACTABLE"] = 1

    print("INFO: Save mac table data success.")


def mergeAndSaveData(existsCollections, table, pkDef, primaryKey, data, objCat, database=None):
    """Merge ``data`` into the stored document and upsert the result.

    The document matching ``primaryKey`` is read from collection ``table``
    (without ``_id``); every top-level key of ``data`` overwrites the stored
    value, keys present only in the stored document are kept. When no document
    exists, ``data`` itself is written. The first time ``table`` is seen in
    ``existsCollections`` the collection's indexes are created and the table is
    recorded there.

    Args:
        existsCollections: dict of collection names already initialised;
            updated in place.
        table: name of the target collection.
        pkDef: primary-key field names used for the unique index.
        primaryKey: filter dict identifying the document.
        data: newly collected object.
        objCat: object category, used to pick the default index fields.
        database: database handle supporting ``database[name]``. When omitted
            the module-level ``db`` created by the script entry point is used,
            which keeps existing six-argument callers working.
    """
    if database is None:
        # Legacy behaviour: rely on the global handle set up under __main__.
        database = db

    collection = database[table]
    mergedData = collection.find_one(primaryKey, {"_id": False})
    if mergedData is None:
        mergedData = data
    else:
        mergedData.update(data)
    collection.replace_one(primaryKey, mergedData, upsert=True)
    print("INFO: Save data success.")

    if table not in existsCollections:
        createPKIndex(collection, pkDef)
        createIndex(collection, data, objCat)
        # Same name, key and options as the TTL index requested by
        # createPKIndex; kept so the TTL index is guaranteed from here as well.
        collection.create_index([("_renewtime", 1)], name="idx_ttl", expireAfterSeconds=15811200)
        existsCollections[table] = 1


def collectToolNames(toolNames):
    """Build the lookup map of plugin names from the positional arguments.

    For an argument that matches ``#{<prefix>.<name>_<digits>.``, the part
    ``<name>_<digits>`` is used. Otherwise a leading ``#{<prefix>.`` and
    everything from the first remaining dot onwards are removed. The resulting
    name is registered together with every shorter name obtained by dropping
    leading dot-separated segments one at a time.

    Args:
        toolNames: iterable of raw tool-name arguments.

    Returns:
        dict mapping each registered name to ``1``.
    """
    toolNamesMap = {}
    for toolName in toolNames:
        matchObj = re.match(r"^#\{.*?\.(.*_\d+)\.", toolName)
        if matchObj:
            toolName = matchObj.group(1)
        else:
            toolName = re.sub(r"^#\{.*?\.", "", toolName)
            toolName = re.sub(r"\..*?$", "", toolName)
        toolNamesMap[toolName] = 1

        # Strip one leading "<segment>." per iteration. The name shrinks on
        # every pass, so the loop ends as soon as no dot is left.
        while True:
            shorterName = re.sub(r"^.*?\.", "", toolName)
            if shorterName == toolName:
                break
            toolNamesMap[shorterName] = 1
            toolName = shorterName
    return toolNamesMap


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batchtag", default="", help="Json document tag for this batch")
    parser.add_argument("--outputfile", default="", help="Output json file path for node")
    parser.add_argument("--node", default="", help="Execution node json")
    parser.add_argument(
        "toolnames",
        nargs=argparse.REMAINDER,
        help="CMDB data collect custome tool names",
    )
    args = parser.parse_args()
    batchTag = args.batchtag
    outputPath = args.outputfile

    # --outputfile wins and is exported; otherwise fall back to the environment.
    if outputPath is None or outputPath == "":
        outputPath = os.getenv("NODE_OUTPUT_PATH")
    else:
        os.environ["NODE_OUTPUT_PATH"] = outputPath

    if outputPath is None:
        print("ERROR: Must set environment variable NODE_OUTPUT_PATH or defined option --outputfile.\n")
        usage()

    collectionNames = {}
    objTypesMap = {}
    toolNamesMap = collectToolNames(args.toolnames)

    (dbclient, db) = AutoExecUtils.getDB()
    # Remember the collect collections that already exist so their indexes are
    # not requested again.
    existsCollections = {}
    for collection in db.list_collection_names():
        if "COLLECT_" in collection:
            existsCollections[collection] = 1

    execUser = os.getenv("AUTOEXEC_USER")
    try:
        outputData = AutoExecUtils.loadNodeOutput()
        if outputData is None:
            print("WARN: Node output data is empty.")
        else:
            # NOTE(review): when several plugins return DATA only the last one
            # is saved - confirm whether their outputs should be accumulated.
            collectData = None
            for key, value in outputData.items():
                if key == "DATA":
                    collectData = value
                    continue

                keyName = re.sub(r"_\d+$", "", key)
                # svcinspect inspection plugins may carry data that supplements
                # the CMDB collect objects, so they are accepted as well.
                isCollectPlugin = (
                    key.startswith("cmdbcollect/")
                    or key.startswith("svcinspect/")
                    or key in toolNamesMap
                    or keyName in toolNamesMap
                )
                if not isCollectPlugin:
                    continue

                if key.startswith("cmdbcollect/savedata"):
                    continue

                pluginData = value.get("DATA")
                if pluginData is None:
                    # Skip this plugin without discarding data found earlier.
                    print("WARN: Plugin {} did not return collect data.".format(key))
                    continue
                collectData = pluginData

            if collectData is None:
                print("ERROR: Can not find any cmdb collect data.")
                exit(-1)

            for data in collectData:
                # Validate the record: it must define a PK and _OBJ_CATEGORY.
                objCat = data.get("_OBJ_CATEGORY")
                objType = data.get("_OBJ_TYPE")
                pkDef = getObjCatPkDef(objCat, data.get("PK"))

                print("INFO: Try to validate data PID:{} OBJ_TYPE:{} IP:{} PORT:{} ...".format(data.get("PID"), objType, data.get("MGMT_IP"), data.get("PORT")))
                isMalformData = 0
                if pkDef is None:
                    isMalformData = 1
                    print("WARN: Data not defined PK.")
                if objCat is None:
                    isMalformData = 1
                    print("WARN: Data not defined _OBJ_CATEGORY.")
                if isMalformData == 1:
                    print(json.dumps(data))
                    continue

                # The display/type name prefers _OBJ_TYPE and falls back to
                # _OBJ_CATEGORY; the collection itself is named by category.
                typeName = objCat
                if objType is not None:
                    typeName = objType

                table = "COLLECT_" + objCat
                collectionNames[table] = 1
                objTypesMap["%s/%s" % (objCat, typeName)] = 1

                pkInvalid = False
                # Build the primary-key filter from the PK definition.
                primaryKey = {}
                for pKey in pkDef:
                    if pKey in data:
                        pVal = data[pKey]
                        primaryKey[pKey] = pVal
                        if pVal is None or pVal == "":
                            # An empty PK value is only reported, not rejected.
                            print("WARN: {} PK attribute:{} is empty.".format(typeName, pKey))
                    else:
                        primaryKey[pKey] = None
                        pkInvalid = True
                        print("WARN: {} not contain PK attribute:{}.".format(typeName, pKey))

                if pkInvalid:
                    continue

                pkJson = json.dumps(primaryKey)

                try:
                    print("INFO: Begin save data {}:{} ...".format(typeName, pkJson))
                    currentTime = datetime.datetime.utcnow()
                    data["_updatetime"] = currentTime
                    data["STATE"] = "InUse"
                    data["_execuser"] = execUser
                    data["_renewtime"] = currentTime
                    data["_batch_tag"] = batchTag

                    # _OBJ_CATEGORY is guaranteed by the validation above and
                    # pkDef is already resolved for this category.
                    saveAppConnData(db, existsCollections, data)
                    saveSwMacTableData(db, existsCollections, data)
                    print("INFO: Save connection information success.")

                    mergeAndSaveData(existsCollections, table, pkDef, primaryKey, data, objCat)

                    print("INFO: Save data success.\n")
                except Exception as ex:
                    # One failing record must not stop the remaining records.
                    print("ERROR: Save data for {}({}) failed, {}".format(typeName, pkJson, ex))
                    traceback.print_exc()
        AutoExecUtils.saveOutput({"collections": collectionNames, "objTypesMap": objTypesMap})
    except Exception:
        print("ERROR: Unknow Error, {}".format(traceback.format_exc()))
        exit(-1)
    finally:
        if dbclient is not None:
            dbclient.close()
